package  cn.gov.msa.pagehelper;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.session.RowBounds;

import com.github.pagehelper.Dialect;
import com.github.pagehelper.Page;
import com.github.pagehelper.SqlUtil;
import com.github.pagehelper.parser.Parser;
import com.github.pagehelper.parser.impl.AbstractParser;
import com.github.pagehelper.sqlsource.PageSqlSource;

/**
 * {@link SqlUtil} variant that resolves its own {@link Parser} and, when that
 * parser is a {@code CustomOracleParse}, hands it an extra parameter array.
 *
 * <p>NOTE(review): {@code isQueryOnly}, {@code doQueryOnly} and
 * {@code doProcessPage} are private and have no call site in this class.
 * Private methods do not override anything in the superclass, so unless a
 * public entry point of {@link SqlUtil} is overridden to call them, the parser
 * held here is not used for paging. Confirm how this class is meant to be
 * wired into the interceptor.
 */
public class CustomSqlUtil extends SqlUtil {
    /**
     * Database-specific parser. Only assigned by the two-argument constructor;
     * it stays {@code null} when the one-argument constructor is used.
     */
    private Parser parser;

    /**
     * Cache of count-query statements keyed by the id of the original statement.
     *
     * <p>NOTE(review): nothing in this class ever puts an entry into this map,
     * and it is private, so lookups always miss. It is presumably meant to
     * mirror a cache filled by {@code processMappedStatement} in the
     * superclass - confirm and share or populate it accordingly.
     */
    private static final Map<String, MappedStatement> msCountMap = new ConcurrentHashMap<String, MappedStatement>();

    /**
     * Creates an instance that relies entirely on the superclass set-up.
     *
     * @param strDialect dialect name or parser class name, passed to {@link SqlUtil}
     */
    public CustomSqlUtil(String strDialect) {
        super(strDialect);
    }

    /**
     * Creates an instance and resolves the parser for the given dialect.
     *
     * <p>The dialect is first looked up through {@link Dialect#of(String)}. If
     * that fails, {@code strDialect} is treated as a fully qualified class name
     * and instantiated reflectively, so callers may plug in their own
     * {@link Parser}. {@code param} is only applied on this reflective path and
     * only when the instance is a {@code CustomOracleParse}.
     *
     * @param strDialect dialect name, or fully qualified name of a {@link Parser} implementation
     * @param param      extra parameters forwarded to {@code CustomOracleParse.setParam}
     * @throws IllegalArgumentException if {@code strDialect} is null or empty
     * @throws RuntimeException         if no parser could be created; the cause describes why
     */
    public CustomSqlUtil(String strDialect, String[] param) {
        super(strDialect);
        if (strDialect == null || "".equals(strDialect)) {
            throw new IllegalArgumentException("Mybatis分页插件无法获取dialect参数!");
        }
        Exception exception = null;
        try {
            Dialect dialect = Dialect.of(strDialect);
            parser = AbstractParser.newParser(dialect);
        } catch (Exception e) {
            exception = e;
            // Not a known dialect: fall back to reflection so that a custom
            // Parser implementation can be supplied by class name.
            try {
                Class<?> parserClass = Class.forName(strDialect);
                if (Parser.class.isAssignableFrom(parserClass)) {
                    parser = (Parser) parserClass.newInstance();
                    if (parser instanceof CustomOracleParse) {
                        ((CustomOracleParse) parser).setParam(param);
                    }
                } else {
                    // The class exists but is unusable; say so instead of only
                    // reporting the unrelated dialect lookup failure.
                    exception = new IllegalArgumentException(
                            strDialect + " does not implement " + Parser.class.getName(), e);
                }
            } catch (ClassNotFoundException ex) {
                exception = ex;
            } catch (InstantiationException ex) {
                exception = ex;
            } catch (IllegalAccessException ex) {
                exception = ex;
            }
        }
        if (parser == null) {
            throw new RuntimeException(exception);
        }
    }

    /**
     * Tells whether the query should run without paging.
     *
     * @param page current paging state
     * @return {@code true} when the page is order-by-only, or when the
     *         pageSizeZero flag is set to true and the page size is 0
     */
    private boolean isQueryOnly(Page page) {
        if (page.isOrderByOnly()) {
            return true;
        }
        return Boolean.TRUE.equals(page.getPageSizeZero()) && page.getPageSize() == 0;
    }

    /**
     * Runs the intercepted query unpaged and reports the result as a single page.
     *
     * @param page       paging state that receives the rows
     * @param invocation intercepted call to proceed with
     * @return {@code page}, holding all rows, with page number 1 and both page
     *         size and total set to the number of rows
     * @throws Throwable whatever the intercepted call throws
     */
    private Page doQueryOnly(Page page, Invocation invocation) throws Throwable {
        page.setCountSignal(null);
        // Execute the normal (unpaged) query.
        Object result = invocation.proceed();
        page.addAll((List) result);
        // Equivalent to fetching the first page with pageSize == total.
        page.setPageNum(1);
        page.setPageSize(page.size());
        // The total still has to be set.
        page.setTotal(page.size());
        // Still return a Page so callers can treat every result uniformly.
        return page;
    }

    /**
     * Performs the optional count query followed by the paged query.
     *
     * <p>{@code args} is expected to hold the statement at index 0, the
     * parameter object at index 1 and the {@link RowBounds} at index 2. The
     * array is modified: index 2 is replaced with {@link RowBounds#DEFAULT} and
     * index 1 with the parameter object returned by the parser.
     *
     * @param invocation intercepted call
     * @param page       paging state to fill
     * @param args       arguments of the intercepted call
     * @return {@code page}, populated with the total and/or the rows
     * @throws IllegalStateException if a count is requested but no count statement is cached
     * @throws Throwable             whatever the intercepted call throws
     */
    private Page doProcessPage(Invocation invocation, Page page, Object[] args) throws Throwable {
        // Remember the caller's RowBounds before it is replaced below.
        RowBounds rowBounds = (RowBounds) args[2];
        // The original statement.
        MappedStatement ms = (MappedStatement) args[0];
        // Convert the statement's SqlSource to a PageSqlSource if needed.
        if (!isPageSqlSource(ms)) {
            processMappedStatement(ms);
        }
        // Set the parser before every use (the original comment notes that it
        // is kept in a ThreadLocal, so re-setting it each time is harmless).
        ((PageSqlSource) ms.getSqlSource()).setParser(parser);
        // Ignore RowBounds, otherwise MyBatis would page in memory.
        args[2] = RowBounds.DEFAULT;
        // Order-by-only or pageSizeZero: run the plain query.
        if (isQueryOnly(page)) {
            return doQueryOnly(page, invocation);
        }

        if (page.isCount()) {
            MappedStatement countMs = msCountMap.get(ms.getId());
            if (countMs == null) {
                // Proceeding with a null statement would only fail later with
                // an opaque error; report the real problem here.
                throw new IllegalStateException(
                        "No count MappedStatement cached for statement id: " + ms.getId());
            }
            page.setCountSignal(Boolean.TRUE);
            Object countResult;
            // Swap in the count statement, and always restore the original one.
            args[0] = countMs;
            try {
                countResult = invocation.proceed();
            } finally {
                args[0] = ms;
            }
            // Accept any numeric count type rather than assuming Integer.
            page.setTotal(((Number) ((List) countResult).get(0)).longValue());
            if (page.getTotal() == 0) {
                return page;
            }
        } else {
            page.setTotal(-1L);
        }

        // Only run the paged query when pageSize > 0; otherwise the result may
        // consist of the count alone.
        boolean explicitRowBounds = rowBounds != RowBounds.DEFAULT;
        if (page.getPageSize() > 0 && (explicitRowBounds || page.getPageNum() > 0)) {
            // Clear the signal while the BoundSql is built, then mark the
            // upcoming execution as the paged (non-count) query.
            page.setCountSignal(null);
            BoundSql boundSql = ms.getBoundSql(args[1]);
            args[1] = parser.setPageParameter(ms, args[1], boundSql, page);
            page.setCountSignal(Boolean.FALSE);
            Object rows = invocation.proceed();
            page.addAll((List) rows);
        }
        return page;
    }

}